{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "from fastai.vision import *\n",
    "import torch\n",
    "from torchsummary import summary\n",
    "# Pin work to GPU 1 only when that device exists; on CPU-only or single-GPU machines\n",
    "# the default device is kept instead of raising.\n",
    "if torch.cuda.is_available() and torch.cuda.device_count() > 1:\n",
    "    torch.cuda.set_device(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset root for Imagenette as returned by fastai's `untar_data`\n",
    "# (presumably downloads and extracts on first use - confirm against the fastai docs).\n",
    "path = untar_data(URLs.IMAGENETTE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 64\n",
    "# fastai transform set with flipping switched off.\n",
    "tfms = get_transforms(do_flip=False)\n",
    "# Build the data bunch from the 'train' / 'val' sub-folders at 224px, then normalise\n",
    "# with `imagenet_stats`.\n",
    "data = ImageDataBunch.from_folder(\n",
    "    path,\n",
    "    train='train',\n",
    "    valid='val',\n",
    "    bs=batch_size,\n",
    "    size=224,\n",
    "    ds_tfms=tfms,\n",
    ").normalize(imagenet_stats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model on GPU\n"
     ]
    }
   ],
   "source": [
    "# Teacher model: a ResNet-34 learner restored from the 'unfreeze_imagenet_bs64' checkpoint.\n",
    "learn = cnn_learner(data, models.resnet34, metrics = accuracy)\n",
    "learn = learn.load('unfreeze_imagenet_bs64')\n",
    "# NOTE(review): fastai's `freeze` presumably leaves the last layer group trainable - confirm;\n",
    "# the optimizers in this notebook are built from `net.parameters()` only.\n",
    "learn.freeze()\n",
    "# learn.summary()\n",
    "\n",
    "class Flatten(nn.Module):\n",
    "    \"\"\"Reshape a batch so that everything after dimension 0 becomes a single axis.\"\"\"\n",
    "\n",
    "    def forward(self, input):\n",
    "        batch = input.size(0)\n",
    "        return input.view(batch, -1)\n",
    "\n",
    "def conv2(ni, nf):\n",
    "    \"\"\"Return `conv_layer(ni, nf)` built with stride 2.\"\"\"\n",
    "    downsampling_conv = conv_layer(ni, nf, stride=2)\n",
    "    return downsampling_conv\n",
    "\n",
    "class ResBlock(nn.Module):\n",
    "    \"\"\"Residual unit: the input plus the output of a single `conv_layer(nf, nf)`.\"\"\"\n",
    "\n",
    "    def __init__(self, nf):\n",
    "        super().__init__()\n",
    "        self.conv1 = conv_layer(nf, nf)\n",
    "\n",
    "    def forward(self, x):\n",
    "        branch = self.conv1(x)\n",
    "        return x + branch\n",
    "\n",
    "def conv_and_res(ni, nf):\n",
    "    \"\"\"`conv2(ni, nf)` followed by `ResBlock(nf)`, wrapped in an `nn.Sequential`.\"\"\"\n",
    "    stages = [conv2(ni, nf), ResBlock(nf)]\n",
    "    return nn.Sequential(*stages)\n",
    "\n",
    "def conv_(nf):\n",
    "    \"\"\"`conv_layer(nf, nf)` followed by `ResBlock(nf)`, wrapped in an `nn.Sequential`.\"\"\"\n",
    "    stages = [conv_layer(nf, nf), ResBlock(nf)]\n",
    "    return nn.Sequential(*stages)\n",
    "    \n",
    "# Student network. For a 3x224x224 input, the summary in the next cell reports feature maps of\n",
    "# 64x112x112 -> 64x56x56 -> 128x28x28 -> 256x14x14 ahead of the pooling layer.\n",
    "net = nn.Sequential(\n",
    "    conv_layer(3, 64, ks = 7, stride = 2, padding = 3),\n",
    "    nn.MaxPool2d(3, 2, padding = 1),\n",
    "    conv_(64),\n",
    "    conv_and_res(64, 128),\n",
    "    conv_and_res(128, 256),\n",
    "    # Max- and average-pooled features are concatenated, hence the 2 * 256 inputs below.\n",
    "    AdaptiveConcatPool2d(),\n",
    "    Flatten(),\n",
    "    # NOTE(review): there is no activation between the two Linear layers. Inserting one would\n",
    "    # shift the nn.Sequential indices used as state_dict keys and break existing checkpoints.\n",
    "    nn.Linear(2 * 256, 128),\n",
    "    nn.Linear(128, 10)\n",
    ")\n",
    "        \n",
    "# net = CNN()\n",
    "# Move the student to the GPU when one is available.\n",
    "if torch.cuda.is_available() : \n",
    "    net = net.cuda()\n",
    "    print('Model on GPU')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------\n",
      "        Layer (type)               Output Shape         Param #\n",
      "================================================================\n",
      "            Conv2d-1         [-1, 64, 112, 112]           9,408\n",
      "              ReLU-2         [-1, 64, 112, 112]               0\n",
      "       BatchNorm2d-3         [-1, 64, 112, 112]             128\n",
      "         MaxPool2d-4           [-1, 64, 56, 56]               0\n",
      "            Conv2d-5           [-1, 64, 56, 56]          36,864\n",
      "              ReLU-6           [-1, 64, 56, 56]               0\n",
      "       BatchNorm2d-7           [-1, 64, 56, 56]             128\n",
      "            Conv2d-8           [-1, 64, 56, 56]          36,864\n",
      "              ReLU-9           [-1, 64, 56, 56]               0\n",
      "      BatchNorm2d-10           [-1, 64, 56, 56]             128\n",
      "         ResBlock-11           [-1, 64, 56, 56]               0\n",
      "           Conv2d-12          [-1, 128, 28, 28]          73,728\n",
      "             ReLU-13          [-1, 128, 28, 28]               0\n",
      "      BatchNorm2d-14          [-1, 128, 28, 28]             256\n",
      "           Conv2d-15          [-1, 128, 28, 28]         147,456\n",
      "             ReLU-16          [-1, 128, 28, 28]               0\n",
      "      BatchNorm2d-17          [-1, 128, 28, 28]             256\n",
      "         ResBlock-18          [-1, 128, 28, 28]               0\n",
      "           Conv2d-19          [-1, 256, 14, 14]         294,912\n",
      "             ReLU-20          [-1, 256, 14, 14]               0\n",
      "      BatchNorm2d-21          [-1, 256, 14, 14]             512\n",
      "           Conv2d-22          [-1, 256, 14, 14]         589,824\n",
      "             ReLU-23          [-1, 256, 14, 14]               0\n",
      "      BatchNorm2d-24          [-1, 256, 14, 14]             512\n",
      "         ResBlock-25          [-1, 256, 14, 14]               0\n",
      "AdaptiveMaxPool2d-26            [-1, 256, 1, 1]               0\n",
      "AdaptiveAvgPool2d-27            [-1, 256, 1, 1]               0\n",
      "AdaptiveConcatPool2d-28            [-1, 512, 1, 1]               0\n",
      "          Flatten-29                  [-1, 512]               0\n",
      "           Linear-30                  [-1, 128]          65,664\n",
      "           Linear-31                   [-1, 10]           1,290\n",
      "================================================================\n",
      "Total params: 1,257,930\n",
      "Trainable params: 1,257,930\n",
      "Non-trainable params: 0\n",
      "----------------------------------------------------------------\n",
      "Input size (MB): 0.57\n",
      "Forward/backward pass size (MB): 38.68\n",
      "Params size (MB): 4.80\n",
      "Estimated Total Size (MB): 44.05\n",
      "----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Print per-layer output shapes and parameter counts of `net` for a 3x224x224 input.\n",
    "# x, y = next(iter(data.train_dl))\n",
    "# net(torch.autograd.Variable(x).cuda())\n",
    "summary(net, (3, 224, 224))\n",
    "# print(learn.summary())\n",
    "# net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SaveFeatures:\n",
    "    \"\"\"Forward hook that keeps the most recent output of module `m` in `self.features`.\n",
    "\n",
    "    `features` is None until the hooked module has run at least once.\n",
    "    Call `remove()` to detach the hook from the module.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, m):\n",
    "        # Set before registering so that reading `features` ahead of the first forward\n",
    "        # pass yields None instead of raising AttributeError.\n",
    "        self.features = None\n",
    "        self.handle = m.register_forward_hook(self.hook_fn)\n",
    "\n",
    "    def hook_fn(self, m, inp, outp):\n",
    "        \"\"\"Hook callback: remember the module output `outp`.\"\"\"\n",
    "        self.features = outp\n",
    "\n",
    "    def remove(self):\n",
    "        \"\"\"Unregister the hook; `features` keeps the last value it stored.\"\"\"\n",
    "        self.handle.remove()\n",
    "        \n",
    "# Hook four intermediate outputs of the teacher body (`mdl[0]`) and the four student stages\n",
    "# that are compared against them; the next cell asserts that their shapes agree.\n",
    "# NOTE(review): indices 2 / 4 / 5 / 6 are presumably the stem ReLU and layer1-3 of resnet34 - confirm.\n",
    "mdl = learn.model\n",
    "sf = [SaveFeatures(m) for m in [mdl[0][2], mdl[0][4], mdl[0][5], mdl[0][6]]]\n",
    "sf2 = [SaveFeatures(m) for m in [net[0], net[2], net[3], net[4]]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: push one batch through teacher and student and verify that every pair of\n",
    "# hooked feature maps has the same shape (required by the MSE feature loss used in training).\n",
    "x, y = next(iter(data.train_dl))\n",
    "if torch.cuda.is_available():\n",
    "    x = x.cuda()\n",
    "# Shapes only: skip the autograd graph to save memory.\n",
    "with torch.no_grad():\n",
    "    out1 = mdl(x)\n",
    "    out2 = net(x)\n",
    "assert len(sf) == len(sf2), 'teacher and student must register the same number of hooks'\n",
    "for i, (teacher_hook, student_hook) in enumerate(zip(sf, sf2)):\n",
    "    assert teacher_hook.features.shape == student_hook.features.shape, (\n",
    "        'feature shape mismatch at hook %d: %s vs %s'\n",
    "        % (i, tuple(teacher_hook.features.shape), tuple(student_hook.features.shape))\n",
    "    )\n",
    "del x, y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Training using teacher model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Train the student `net` on cross-entropy plus an MSE feature-matching term for every\n",
    "# pair of hooked teacher / student activations (`sf` / `sf2`).\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr = 1e-4)\n",
    "num_epochs = 100\n",
    "total_step = len(data.train_ds) // batch_size\n",
    "train_loss_list = list()\n",
    "val_loss_list = list()\n",
    "# Any finite first validation loss must count as an improvement.\n",
    "min_val = float('inf')\n",
    "# The training loss sums one MSE term per hooked stage plus one cross-entropy term;\n",
    "# the logged value is the mean over those terms.\n",
    "n_loss_terms = len(sf) + 1\n",
    "use_cuda = torch.cuda.is_available()\n",
    "# The teacher is a fixed target: eval mode keeps its BatchNorm running statistics from\n",
    "# drifting while the student trains.\n",
    "mdl.eval()\n",
    "for epoch in range(num_epochs):\n",
    "    trn = []\n",
    "    net.train()\n",
    "    for i, (images, labels) in enumerate(data.train_dl):\n",
    "        images = images.float()\n",
    "        if use_cuda:\n",
    "            images, labels = images.cuda(), labels.cuda()\n",
    "\n",
    "        # Student forward pass; the sf2 hooks capture graph-attached activations.\n",
    "        y_pred = net(images)\n",
    "        # Teacher forward pass; the sf hooks capture its activations. No autograd graph is\n",
    "        # needed because the optimizer never updates the teacher.\n",
    "        with torch.no_grad():\n",
    "            mdl(images)\n",
    "\n",
    "        loss = F.cross_entropy(y_pred, labels)\n",
    "        # Iterate over the hooks that actually exist: the previous range(5) indexed past the\n",
    "        # four registered hooks and raised IndexError.\n",
    "        for teacher_hook, student_hook in zip(sf, sf2):\n",
    "            loss = loss + F.mse_loss(student_hook.features, teacher_hook.features.detach())\n",
    "        trn.append(loss.item() / n_loss_terms)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "#         torch.nn.utils.clip_grad_value_(net.parameters(), 10)\n",
    "        optimizer.step()\n",
    "\n",
    "        if i % 200 == 199:\n",
    "            print('epoch = ', epoch, ' step = ', i + 1, ' of total steps ', total_step, ' loss = ', loss.item() / n_loss_terms)\n",
    "\n",
    "    train_loss = (sum(trn) / len(trn))\n",
    "    train_loss_list.append(train_loss)\n",
    "\n",
    "    # Validation uses plain cross-entropy only.\n",
    "    net.eval()\n",
    "    val = []\n",
    "    with torch.no_grad():\n",
    "        for images, labels in data.valid_dl:\n",
    "            images = images.float()\n",
    "            if use_cuda:\n",
    "                images, labels = images.cuda(), labels.cuda()\n",
    "\n",
    "            outputs = net(images)\n",
    "            val.append(F.cross_entropy(outputs, labels).item())\n",
    "\n",
    "    val_loss = sum(val) / len(val)\n",
    "    val_loss_list.append(val_loss)\n",
    "    print('epoch : ', epoch + 1, ' / ', num_epochs, ' | TL : ', train_loss, ' | VL : ', val_loss)\n",
    "\n",
    "    # Keep only the weights with the best validation loss seen so far.\n",
    "    if val_loss < min_val:\n",
    "        print('saving model')\n",
    "        min_val = val_loss\n",
    "        torch.save(net.state_dict(), '../saved_models/model0.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# learn = Learner(data, net, metrics = accuracy)\n",
    "# Move the model to the CPU, load the checkpoint with map_location='cpu', then move the\n",
    "# model back to the GPU.\n",
    "net.cpu()\n",
    "net.load_state_dict(torch.load('../saved_models/small_no_teacher/model0.pt', map_location = 'cpu'))\n",
    "net = net.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One point per recorded epoch. The x-range follows the history length instead of the\n",
    "# hard-coded 38, which only worked for a run that recorded exactly 38 epochs.\n",
    "plt.plot(range(len(train_loss_list)), train_loss_list, label='train loss')\n",
    "plt.plot(range(len(val_loss_list)), val_loss_list, label='validation loss')\n",
    "plt.xlabel('epoch')\n",
    "plt.ylabel('loss')\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.108"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def _get_accuracy(dataloader, Net):\n",
    "    \"\"\"Return the fraction of samples in `dataloader` that `Net` classifies correctly.\n",
    "\n",
    "    Args:\n",
    "        dataloader: iterable of `(images, labels)` batches; labels hold class indices.\n",
    "        Net: module mapping an image batch to per-class scores of shape (batch, classes).\n",
    "\n",
    "    Returns:\n",
    "        `correct / total` as a float, or 0.0 when the dataloader yields no samples.\n",
    "\n",
    "    Side effects:\n",
    "        Switches `Net` to eval mode and leaves it there.\n",
    "    \"\"\"\n",
    "    total = 0\n",
    "    correct = 0\n",
    "    Net.eval()\n",
    "    use_cuda = torch.cuda.is_available()\n",
    "    # Inference only: without no_grad every batch would build an autograd graph.\n",
    "    with torch.no_grad():\n",
    "        for images, labels in dataloader:\n",
    "            images = images.float()\n",
    "            if use_cuda:\n",
    "                images, labels = images.cuda(), labels.cuda()\n",
    "\n",
    "            # argmax over raw scores equals argmax over log-softmax, so no softmax is needed.\n",
    "            # Calling the module (not .forward) keeps registered hooks consistent.\n",
    "            pred_ind = Net(images).argmax(dim=1)\n",
    "\n",
    "            correct += (pred_ind == labels.long()).sum().item()\n",
    "            total += labels.numel()\n",
    "\n",
    "    # Guard against an empty dataloader instead of dividing by zero.\n",
    "    if total == 0:\n",
    "        return 0.0\n",
    "    return correct / total\n",
    "\n",
    "# Validation accuracy of the currently loaded `net`, shown as the cell output.\n",
    "_get_accuracy(data.valid_dl, net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _plain_cnn():\n",
    "    \"\"\"Five (3x3 conv -> 2x2 max-pool) stages, then Flatten and two Linear layers.\n",
    "\n",
    "    Modules are created in the same order as before, so nn.Sequential indices (and\n",
    "    therefore state_dict keys) are unchanged.\n",
    "    \"\"\"\n",
    "    widths = [3, 64, 64, 128, 256, 512]\n",
    "    modules = []\n",
    "    for c_in, c_out in zip(widths[:-1], widths[1:]):\n",
    "        modules.append(nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, stride=1, padding=1))\n",
    "        modules.append(nn.MaxPool2d(kernel_size=2, stride=2))\n",
    "    modules.append(Flatten())\n",
    "    modules.append(nn.Linear(512 * 7 * 7, 512))\n",
    "    modules.append(nn.Linear(512, 10))\n",
    "    return nn.Sequential(*modules)\n",
    "\n",
    "net = _plain_cnn().cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Training without teacher model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Baseline run: train `net` on cross-entropy alone and checkpoint on best validation loss.\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)\n",
    "\n",
    "num_epochs = 100\n",
    "total_step = len(data.train_ds) // batch_size\n",
    "train_loss_list, val_loss_list = [], []\n",
    "min_val = 100\n",
    "for epoch in range(num_epochs):\n",
    "    # ---- training pass ----\n",
    "    net.train()\n",
    "    trn = []\n",
    "    for i, (images, labels) in enumerate(data.train_dl):\n",
    "        on_gpu = torch.cuda.is_available()\n",
    "        images = images.cuda().float() if on_gpu else images.float()\n",
    "        labels = labels.cuda() if on_gpu else labels\n",
    "\n",
    "        y_pred = net(images)\n",
    "        loss = F.cross_entropy(y_pred, labels)\n",
    "        step_loss = loss.item()\n",
    "        trn.append(step_loss)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Progress line every 50th step.\n",
    "        if (i + 1) % 50 == 0:\n",
    "            print('epoch = ', epoch, ' step = ', i + 1, ' of total steps ', total_step, ' loss = ', step_loss)\n",
    "\n",
    "    train_loss = sum(trn) / len(trn)\n",
    "    train_loss_list.append(train_loss)\n",
    "\n",
    "    # ---- validation pass ----\n",
    "    net.eval()\n",
    "    val = []\n",
    "    with torch.no_grad():\n",
    "        for images, labels in data.valid_dl:\n",
    "            on_gpu = torch.cuda.is_available()\n",
    "            images = images.cuda().float() if on_gpu else images.float()\n",
    "            labels = labels.cuda() if on_gpu else labels\n",
    "\n",
    "            outputs = net(images)\n",
    "            val.append(F.cross_entropy(outputs, labels))\n",
    "\n",
    "    val_loss = (sum(val) / len(val)).item()\n",
    "    val_loss_list.append(val_loss)\n",
    "    val_acc = _get_accuracy(data.valid_dl, net)\n",
    "    print('epoch : ', epoch + 1, ' / ', num_epochs, ' | TL : ', train_loss, ' | VL : ', val_loss, ' | VA : ', val_acc * 100)\n",
    "\n",
    "    # Persist the weights whenever validation loss improves.\n",
    "    if val_loss < min_val:\n",
    "        print('saving model')\n",
    "        min_val = val_loss\n",
    "        torch.save(net.state_dict(), '../saved_models/model1_normal.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore a saved checkpoint into `net` and report its validation accuracy.\n",
    "# NOTE(review): the training cell in this notebook saves 'model1_normal.pt', not\n",
    "# 'model4_normal.pt' - confirm that this file exists and is the intended one.\n",
    "net.load_state_dict(torch.load('../saved_models/model4_normal.pt'))\n",
    "_get_accuracy(data.valid_dl, net)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _evaluate(folder_name, net, n_models=5):\n",
    "    \"\"\"Print validation accuracies (in %) of a family of checkpoints, with mean and std.\n",
    "\n",
    "    Args:\n",
    "        folder_name: sub-folder of '../saved_models/' holding model0.pt, model1.pt, ...\n",
    "        net: model whose weights are overwritten in place by each checkpoint in turn.\n",
    "        n_models: number of consecutive checkpoints to evaluate, starting at index 0.\n",
    "\n",
    "    Returns:\n",
    "        None; results are printed.\n",
    "    \"\"\"\n",
    "    acc_list = list()\n",
    "    for idx in range(n_models):\n",
    "        filename = '../saved_models/' + folder_name + '/model' + str(idx) + '.pt'\n",
    "        # map_location='cpu' lets a checkpoint saved on another GPU index load on any machine;\n",
    "        # load_state_dict then copies the values onto the device `net` already lives on.\n",
    "        net.load_state_dict(torch.load(filename, map_location='cpu'))\n",
    "        acc_list.append(_get_accuracy(data.valid_dl, net))\n",
    "\n",
    "    # Convert fractions to percentages for reporting.\n",
    "    acc_list = [acc * 100 for acc in acc_list]\n",
    "    print(acc_list)\n",
    "    print(np.mean(acc_list))\n",
    "    print(np.std(acc_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[91.8, 92.0, 91.8, 89.8, 91.0]\n",
      "91.28\n",
      "0.8158431221748459\n"
     ]
    }
   ],
   "source": [
    "# Accuracies (in %) of the 'small_classifier' checkpoints, followed by their mean and std.\n",
    "_evaluate('small_classifier', net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (ak_fastai)",
   "language": "python",
   "name": "ak_fastai"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
